{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "from tensorflow.contrib import slim\n",
    "import numpy as np\n",
    "import tqdm\n",
    "from tqdm.notebook import tqdm_notebook\n",
    "from matplotlib import pyplot as plt\n",
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "BATCH_SIZE = 512\n",
    "LR_PRIMAL = 2e-5\n",
    "LR_DUAL = 1e-4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "def get_data_samples(N):\n",
    "    data = tf.random_uniform([N], minval=0, maxval=4, dtype=tf.int32)\n",
    "    return data\n",
    "\n",
    "def encoder_func(x, eps):\n",
    "    net = tf.concat([x, eps], axis=-1)\n",
    "    net = slim.fully_connected(net, 64, activation_fn=tf.nn.elu)\n",
    "    net = slim.fully_connected(net, 64, activation_fn=tf.nn.elu)\n",
    "    net = slim.fully_connected(net, 64, activation_fn=tf.nn.elu)\n",
    "\n",
    "    z = slim.fully_connected(net, 2, activation_fn=None)\n",
    "\n",
    "    return z\n",
    "\n",
    "\n",
    "def decoder_func(z):\n",
    "    net = z\n",
    "    net = slim.fully_connected(net, 64, activation_fn=tf.nn.elu)\n",
    "    net = slim.fully_connected(net, 64, activation_fn=tf.nn.elu)\n",
    "    net = slim.fully_connected(net, 64, activation_fn=tf.nn.elu)\n",
    "\n",
    "    xlogits = slim.fully_connected(net, 4, activation_fn=None)\n",
    "    return xlogits\n",
    "\n",
    "def discriminator_func(x, z):\n",
    "    net = tf.concat([x, z], axis=1)\n",
    "    net =  slim.fully_connected(net, 256, activation_fn=tf.nn.elu)\n",
    "    for i in range(5):\n",
    "        dnet = slim.fully_connected(net, 256, scope='fc_%d_r0' % (i+1))\n",
    "        net += slim.fully_connected(dnet, 256, activation_fn=None, scope='fc_%d_r1' % (i+1),\n",
    "                                    weights_initializer=tf.constant_initializer(0.))\n",
    "        net = tf.nn.elu(net) \n",
    "\n",
    "#     net =  slim.fully_connected(net, 512, activation_fn=tf.nn.elu)\n",
    "    net =  slim.fully_connected(net, 1, activation_fn=None)\n",
    "    net = tf.squeeze(net, axis=1)\n",
    "    net += tf.reduce_sum(tf.square(z), axis=1)\n",
    "    \n",
    "    return net\n",
    "\n",
    "def create_scatter(x_test_labels, eps_test, savepath=None):\n",
    "    plt.figure(figsize=(5,5), facecolor='w')\n",
    "\n",
    "    for i in range(4):\n",
    "        z_out = sess.run(z_inferred, feed_dict={x_real_labels: x_test_labels[i], eps: eps_test})\n",
    "        plt.scatter(z_out[:, 0], z_out[:, 1],  edgecolor='none', alpha=0.5)\n",
    "\n",
    "    plt.xlim(-3, 3); plt.ylim(-3.5, 3.5)\n",
    "\n",
    "    plt.axis('off')\n",
    "    if savepath:\n",
    "        plt.savefig(savepath, dpi=512)\n",
    "\n",
    "encoder = tf.make_template('encoder', encoder_func)\n",
    "decoder = tf.make_template('decoder', decoder_func)\n",
    "discriminator = tf.make_template('discriminator', discriminator_func)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "eps = tf.random_normal([BATCH_SIZE, 64])\n",
    "x_real_labels = get_data_samples(BATCH_SIZE)\n",
    "x_real = tf.one_hot(x_real_labels, 4)\n",
    "z_sampled = tf.random_normal([BATCH_SIZE, 2])\n",
    "z_inferred = encoder(x_real, eps)\n",
    "x_reconstr_logits = decoder(z_inferred)\n",
    "\n",
    "Tjoint = discriminator(x_real, z_inferred)\n",
    "Tseperate = discriminator(x_real, z_sampled)\n",
    "\n",
    "reconstr_err = tf.reduce_sum(\n",
    "    tf.nn.sigmoid_cross_entropy_with_logits(labels=x_real, logits=x_reconstr_logits),\n",
    "    axis=1\n",
    ")\n",
    "\n",
    "loss_primal = tf.reduce_mean(reconstr_err + Tjoint)\n",
    "loss_dual = tf.reduce_mean(\n",
    "    tf.nn.sigmoid_cross_entropy_with_logits(logits=Tjoint, labels=tf.ones_like(Tjoint))\n",
    "    + tf.nn.sigmoid_cross_entropy_with_logits(logits=Tseperate, labels=tf.zeros_like(Tseperate))\n",
    ")\n",
    "\n",
    "optimizer_primal = tf.train.AdamOptimizer(LR_PRIMAL)\n",
    "optimizer_dual = tf.train.AdamOptimizer(LR_DUAL)\n",
    "\n",
    "qvars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, \"encoder\")\n",
    "pvars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, \"decoder\")\n",
    "dvars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, \"discriminator\")\n",
    "\n",
    "train_op_primal = optimizer_primal.minimize(loss_primal, var_list=pvars+qvars)\n",
    "train_op_dual = optimizer_dual.minimize(loss_dual, var_list=dvars)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "sess = tf.InteractiveSession()\n",
    "sess.run(tf.global_variables_initializer())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": false,
    "tags": []
   },
   "outputs": [
    {
     "output_type": "display_data",
     "data": {
      "text/plain": "HBox(children=(FloatProgress(value=0.0, max=100000.0), HTML(value='')))",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "94392ba71e674f8b8ef84da4eba2b3d1"
      }
     },
     "metadata": {}
    },
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "\n"
    },
    {
     "output_type": "error",
     "ename": "NameError",
     "evalue": "name 'loss_primal' is not defined",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-5-09df2623a3f4>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      8\u001b[0m \u001b[0mprogress\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtqdm_notebook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m100000\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      9\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mprogress\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 10\u001b[0;31m     \u001b[0mELBO_out\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msess\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mloss_primal\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrain_op_primal\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     11\u001b[0m     \u001b[0msess\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_op_dual\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     12\u001b[0m     \u001b[0msess\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_op_dual\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mNameError\u001b[0m: name 'loss_primal' is not defined"
     ]
    }
   ],
   "source": [
    "x_test_labels = [[i] * BATCH_SIZE for i in range(4)]\n",
    "eps_test = np.random.randn(BATCH_SIZE, 64) \n",
    "\n",
    "outdir = './out_toy'\n",
    "if not os.path.exists(outdir):\n",
    "    os.makedirs(outdir)\n",
    "    \n",
    "progress = tqdm_notebook(range(100000))\n",
    "for i in progress:\n",
    "    ELBO_out, _ = sess.run([loss_primal, train_op_primal])\n",
    "    sess.run(train_op_dual)\n",
    "    sess.run(train_op_dual)\n",
    "\n",
    "    progress.set_description('ELBO = %.2f' % ELBO_out)\n",
    "    if i % 100 == 0:\n",
    "        create_scatter(x_test_labels, eps_test, savepath=os.path.join(outdir, '%08d.png' % i))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.5 64-bit",
   "language": "python",
   "name": "python_defaultSpec_1596871389747"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  },
  "widgets": {
   "state": {
    "13d1dbbaf9a741d69739f52d03a9e4e1": {
     "views": [
      {
       "cell_index": 5
      }
     ]
    }
   },
   "version": "1.2.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}